from __future__ import annotations

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import warnings
import logging
import copy
import torch
from ..base import BaseModel
from .prompt import ThymePromptMixin
from .sandbox import execute_code_in_sandbox
from .utils import (generate_prompt_final_qa, generate_prompt_simple_qa, SPECIAL_STRING_LIST, 
                    REASONING_SYS_PROMPT, SIMPLE_SYS_PROMPT
                    )
import re
from transformers.cache_utils import DynamicCache
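
# Thyme wraps a Qwen2.5-VL checkpoint in an agentic inference loop: the model
# "thinks" in text, may emit a <code>...</code> block that is executed in a
# sandbox (see execute_code_in_sandbox), receives the resulting images back
# inside <sandbox_output>...</sandbox_output>, and finally answers inside
# <answer>...</answer> tags.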

def ensure_image_url(image: str) -> str:
    prefixes = ['http://', 'https://', 'file://', 'data:image;']
    if any(image.startswith(prefix) for prefix in prefixes):
        return image
    if os.path.exists(image):
        return 'file://' + image
    raise ValueError(f'Invalid image: {image}')
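# Example (illustrative, assuming /tmp/cat.png exists on disk):
#   ensure_image_url('/tmp/cat.png') -> 'file:///tmp/cat.png'
# URLs and data URIs pass through unchanged.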

def ensure_video_url(video: str) -> str:
    prefixes = ['http://', 'https://', 'file://', 'data:video;']
    if any(video.startswith(prefix) for prefix in prefixes):
        return video
    if os.path.exists(video):
        return 'file://' + video
    raise ValueError(f'Invalid video: {video}')


class Thyme(ThymePromptMixin, BaseModel):
    INSTALL_REQ = False
    INTERLEAVE = True
    VIDEO_LLM = True

    def __init__(
        self,
        model_path: str,
        min_pixels: int | None = None,
        max_pixels: int | None = None,
        max_new_tokens=2048,
        top_p=0.001,
        top_k=1,
        temperature=0.01,
        repetition_penalty=1.0,
        max_iterations=5,       # pandayin: rounds of intermediate steps before reaching the final answer.
        max_retry=5,            # pandayin: maximum retries before giving up on finding a valid answer.
        use_custom_prompt: bool = True,
        system_prompt: str | None = "You are a helpful assistant.",
        post_process: bool = True,  # if True, try to extract only the content wrapped in <answer> & </answer>.
        verbose: bool = False,
        **kwargs,
    ):
        super().__init__(use_custom_prompt=use_custom_prompt)
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.top_p = top_p
        self.top_k = top_k
        self.temperature = temperature
        self.system_prompt = system_prompt
        self.max_iterations = max_iterations
        self.max_retry = max_retry
        self.verbose = verbose
        self.post_process = post_process
        self.fps = 2.0
        self.nframe = 64
        self.FRAME_FACTOR = 2
        assert model_path is not None
        self.model_path = model_path
        from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
        MODEL_CLS = Qwen2_5_VLForConditionalGeneration
        self.processor = AutoProcessor.from_pretrained(model_path)
            
        self.generate_kwargs = dict(
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            stop_strings=SPECIAL_STRING_LIST,
            eos_token_id=self.processor.tokenizer.eos_token_id,
            tokenizer=self.processor.tokenizer
        )
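        # Generation pauses whenever one of SPECIAL_STRING_LIST (the special
        # tags such as </code> and </answer>) is emitted; transformers requires
        # the `tokenizer` to be passed alongside `stop_strings` to match them.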
        
        self.model = MODEL_CLS.from_pretrained(
            model_path, torch_dtype='auto', device_map='auto', attn_implementation='sdpa'
        )
        self.model.eval()

        torch.cuda.empty_cache()

    def _extract_image_path(self, contents: list[dict[str, str]]) -> str:
        """Return the path/URL of the first image entry in the message, or '' if none."""
        for item in contents:
            if item['type'] == 'image':
                return item['value']
        return ""
    
    def _prepare_content(self, inputs: list[dict[str, str]], dataset: str | None = None) -> list[dict[str, str]]:
        """
        inputs list[dict[str, str]], each dict has keys: ['type', 'value']
        """
        user_image_path = self._extract_image_path(inputs)
        content = []
        for s in inputs:
            if s['type'] == 'image':
                item = {'type': 'image', 'image': ensure_image_url(s['value'])}
                if dataset == 'OCRBench':
                    item['min_pixels'] = 10 * 10 * 28 * 28
                    warnings.warn(f"OCRBench dataset uses custom min_pixels={item['min_pixels']}")
                    if self.max_pixels is not None:
                        item['max_pixels'] = self.max_pixels
                else:
                    if self.min_pixels is not None:
                        item['min_pixels'] = self.min_pixels
                    if self.max_pixels is not None:
                        item['max_pixels'] = self.max_pixels
            elif s['type'] == 'video':
                item = {
                    'type': 'video',
                    'video': ensure_video_url(s['value']),
                    'min_pixels': self.min_pixels,
                    'max_pixels': self.max_pixels
                }
                if self.fps is not None:
                    item['fps'] = self.fps
                elif self.nframe is not None:
                    import cv2
                    video = cv2.VideoCapture(s['value'])
                    frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
                    video.release()
                    if frame_count < self.nframe:
                        new_frame_count = frame_count // self.FRAME_FACTOR * self.FRAME_FACTOR
                        print(f"use {new_frame_count} for {s['value']}")
                        item['nframes'] = new_frame_count
                    else:
                        item['nframes'] = self.nframe
            # pandayin: wrap the user query with the customized prompt.
            # TODO: consider supporting multiple images or videos.
            elif s['type'] == 'text':
                item = {'type': 'text', 'text': generate_prompt_final_qa(s['value'], user_image_path)}
            else:
                raise ValueError(f"Invalid message type: {s['type']}, {s}")
            content.append(item)
        return content

    def _prepare_content_simple(self, inputs: list[dict[str, str]], dataset: str | None = None) -> list[dict[str, str]]:
        """
        inputs list[dict[str, str]], each dict has keys: ['type', 'value']
        """
        content = []
        for s in inputs:
            if s['type'] == 'image':
                item = {'type': 'image', 'image': ensure_image_url(s['value'])}
                if dataset == 'OCRBench':
                    item['min_pixels'] = 10 * 10 * 28 * 28
                    warnings.warn(f"OCRBench dataset uses custom min_pixels={item['min_pixels']}")
                    if self.max_pixels is not None:
                        item['max_pixels'] = self.max_pixels
                else:
                    if self.min_pixels is not None:
                        item['min_pixels'] = self.min_pixels
                    if self.max_pixels is not None:
                        item['max_pixels'] = self.max_pixels
            elif s['type'] == 'video':
                item = {
                    'type': 'video',
                    'video': ensure_video_url(s['value']),
                    'min_pixels': self.min_pixels,
                    'max_pixels': self.max_pixels
                }
                if self.fps is not None:
                    item['fps'] = self.fps
                elif self.nframe is not None:
                    import cv2
                    video = cv2.VideoCapture(s['value'])
                    frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
                    video.release()
                    if frame_count < self.nframe:
                        new_frame_count = frame_count // self.FRAME_FACTOR * self.FRAME_FACTOR
                        print(f"use {new_frame_count} for {s['value']}")
                        item['nframes'] = new_frame_count
                    else:
                        item['nframes'] = self.nframe
            # pandayin: wrap the user query with the customized prompt.
            # TODO: consider supporting multiple images or videos.
            elif s['type'] == 'text':
                item = {'type': 'text', 'text': generate_prompt_simple_qa(s['value'])}
            else:
                raise ValueError(f"Invalid message type: {s['type']}, {s}")
            content.append(item)
        return content
    
    def _extract_box_answer(self, response):
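        # Brace-aware extraction: scan everything after the last '\boxed{',
        # tracking nesting depth, so nested braces such as \boxed{\frac{1}{2}}
        # are captured whole (a non-greedy regex would stop at the first '}').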
        resp = response.split('\\boxed{')[-1]
        lt = len(resp)
        counter, end = 1, None
        for i in range(lt):
            if resp[i] == '{':
                counter += 1
            elif resp[i] == '}':
                counter -= 1
            if counter == 0:
                end = i
                break
            elif i == lt - 1:
                end = lt
                break
        if end is not None:
            response = resp[:end]
        return response
    
    def _remove_unpicklable_values(self, dictionary):
        """Recursively drop values that cannot be pickled, so the dict can be deep-copied."""
        import pickle

        def is_picklable(obj):
            try:
                pickle.dumps(obj)
                return True
            except (pickle.PicklingError, TypeError, AttributeError):
                return False

        keys_to_remove = []
        for key, value in dictionary.items():
            if isinstance(value, dict):
                self._remove_unpicklable_values(value)
            elif not is_picklable(value):
                keys_to_remove.append(key)
        for key in keys_to_remove:
            del dictionary[key]
        return dictionary
        

    def generate_inner_transformers(self, message, dataset=None):
        try:
            from qwen_vl_utils import process_vision_info
        except Exception as err:
            logging.critical("qwen_vl_utils not found, please install it via 'pip install qwen-vl-utils'")  # noqa: E501
            raise err
        
        # pandayin: get image path from the input sample.
        user_image_path = self._extract_image_path(message)
        
        messages = []
        messages.append({'role': 'system', 'content': REASONING_SYS_PROMPT})
        messages.append({'role': 'user', 'content': self._prepare_content(message, dataset=dataset)})
        
        
        if self.verbose:
            print(f'\033[31m{messages}\033[0m')

        #   -------   outer loop: retry multiple times if we fail to reach a valid answer.   -------
        retry_generations = self.max_retry
        has_valid_answer = False
        while (retry_generations > 0) and (not has_valid_answer):
            # pandayin: main logic / workflow for generation.
            # The gist is to pause at special strings (</code> & </answer>) and optionally execute code.
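            # Per-iteration protocol (as implemented below):
            #   1. render the conversation and call generate(), which stops at
            #      </code> or </answer> via `stop_strings`;
            #   2. on <code>, execute it in the sandbox and splice the results
            #      back into the conversation as <sandbox_output> content;
            #   3. on </answer>, accept the answer and stop;
            #   4. on sandbox failure, roll back the KV cache and execution
            #      context and retry the step.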
            conversation_history = copy.deepcopy(messages)
            
            # For each generation, we initialize a KV-Cache to speed up inference.
            kv_cache = DynamicCache()
            # Maintain a dictionary to save context (local & global vars.) for code execution.
            previous_execution_context = {}
            if self.verbose:
                print(f'\033[32m\n--- Generation {self.max_retry - retry_generations + 1} ---\033[0m')
                    
            #   -------   inner loop: generate multiple steps until a valid answer is reached.   -------
            retry_iterations = self.max_iterations
            # We assume each answer round needs only a few code executions (usually one).
            while retry_iterations > 0:
                retry_iterations -= 1
                generated_content = []
                if self.verbose:
                    print(f'\033[32m\n--- Iteration {self.max_iterations - retry_iterations} ---\033[0m')
                
                
                # Only the first iteration opens a fresh assistant turn; on later
                # iterations we re-render the full history and strip the trailing
                # <|im_end|>\n so the assistant message is continued rather than closed.
                text = self.processor.apply_chat_template(
                    [conversation_history],
                    tokenize=False,
                    add_generation_prompt=(retry_iterations == self.max_iterations - 1)
                )

                if retry_iterations != self.max_iterations - 1:
                    if text[0].endswith("<|im_end|>\n"):
                        text[0] = text[0][:-len("<|im_end|>\n")]
                images, videos = process_vision_info([conversation_history])
                inputs = self.processor(text=text, images=images, videos=videos, padding=True, return_tensors='pt')
                inputs = inputs.to('cuda')
                
                # Back up state so we can roll back if this iteration turns out to be
                # invalid (e.g. the generated code fails to execute in the sandbox).
                # Unpicklable entries are dropped first so the deepcopy succeeds.
                last_kv_cache = copy.deepcopy(kv_cache)
                last_execution_context = copy.deepcopy(self._remove_unpicklable_values(previous_execution_context))
                generated_ids = self.model.generate(
                    **inputs,
                    **self.generate_kwargs,
                    past_key_values=kv_cache
                )
                generated_ids = [
                    output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
                ]
                out = self.processor.tokenizer.batch_decode(
                    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
                )
                generated_text_segment = out[0]
                
                # Case 1: the model produced a final answer directly.
                if "</answer>" in generated_text_segment:
                    generated_content.append(
                        {"type": "text", "text": generated_text_segment},
                    )
                
                # Case 2: the model emitted code.
                # Parse the current segment; generation stopped at either </code> or </answer>.
                code_regex = re.compile(r'<code>\s*(?:```\s*)?(?:python\s*)?([\s\S]*?)\s*(?:```\s*)?</code>', re.IGNORECASE)
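                # The regex tolerates optional markdown fences and a "python" tag
                # inside the <code> block, so both <code>print(1)</code> and
                # <code>```python ... ```</code> are matched.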
                
                code_match = code_regex.search(generated_text_segment)
                
                # execute code and return result.
                if code_match:
                    code_to_execute = code_match.group(1).strip()
                    if self.verbose:
                        print(f"\033[31m--- Found Code Block ---\n{generated_text_segment}\n-------------------------\033[0m")

                    processed_img_paths, captured_stdout, error_msg, current_execution_context = execute_code_in_sandbox(
                        code_to_execute, user_image_path,
                        previous_execution_context=previous_execution_context
                    )
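                    # The sandbox returns saved image paths, captured stdout, an
                    # error message, and the updated variable context; the context
                    # is threaded into the next code block so later snippets can
                    # reuse earlier state.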
                    previous_execution_context = current_execution_context
                    if not processed_img_paths:
                        # Deemed an unsuccessful iteration: roll back the saved state.
                        kv_cache = last_kv_cache
                        previous_execution_context = last_execution_context
                        print(error_msg)
                        continue
                       
                    has_valid_images = False
                    generated_content += [
                        {"type": "text", "text": generated_text_segment},
                        {"type": "text", "text": "<sandbox_output>"}
                    ]
                    first_path = processed_img_paths[0]
                    if os.path.exists(first_path):
                        # Append every image the sandbox actually produced.
                        for img_path in processed_img_paths:
                            if os.path.exists(img_path):
                                has_valid_images = True
                                generated_content.append({"type": "image", "image": img_path})
                    else:
                        # The first entry is not a file on disk; treat it as plain text output.
                        generated_content.append({"type": "text", "text": first_path})

                    if has_valid_images or not os.path.exists(first_path):
                        generated_content.append({"type": "text", "text": "</sandbox_output>"})
                    else:
                        # pandayin: a failed code execution/generation doesn't count as an intermediate step.
                        print('Skipping this generation due to an error; adapting the temperature.')
                        self.generate_kwargs['temperature'] = 1.0
                        continue
                else:
                    # No code and no </answer>: assume degenerate/repetitive output and break.
                    if "</answer>" not in generated_text_segment:
                        print('No <code> block and no </answer> tag in the generated segment.')
                        print(generated_text_segment)
                        self.generate_kwargs['temperature'] = 1.0
                        break

                # Update conversation_history with the latest generated segment
                # If the last message was 'user', start a new 'assistant' message
                if conversation_history[-1]["role"] == "user":
                    conversation_history.append({"role": "assistant", "content": generated_content})
                # If the last message was 'assistant', append to its last text content item
                elif conversation_history[-1]["role"] == "assistant":
                    conversation_history[-1]["content"] += generated_content
                
                # --- Check for the final answer tag in this segment ---
                if "</answer>" in generated_text_segment:
                    has_valid_answer = True
                    print("\033[32m--- Final answer tag found. ---\033[0m")
                    break
                
                # If the model produced an EOS token and no code/answer, it might be finished
                if generated_ids[0][-1] == self.processor.tokenizer.eos_token_id:
                    if self.verbose:
                        print("\033[32m--- Model generated EOS and no further actions (code/answer). Assuming completion. ---\033[0m")
                    break
            
            # End of a generation: either we found a valid answer, or we start a new one.
            if has_valid_answer:
                if self.verbose:
                    print(f"\033[32m\n--- End of processing (max iterations: {self.max_iterations}, actual: {self.max_iterations - retry_iterations}) ---\033[0m")
                break
            if self.verbose:
                print(f"\033[32m\n --- Failed to find a valid answer. (max retries: {self.max_retry}, actual: {self.max_retry - retry_generations + 1}) ---\033[0m")

            retry_generations -= 1
            # pandayin: TODO: maybe adjust/reset generation_kwargs here so more exploration can be done to find a valid answer.
            print('Failed to find a valid answer; adapting the temperature.')
            self.generate_kwargs['temperature'] = 1.0

        
        # Reset the sampling temperature to its configured value.
        self.generate_kwargs['temperature'] = self.temperature
        
        # pandayin: if we still fail after max_retry generations, fall back to a simple prompt.
        if not has_valid_answer:
            print(f"\033[32m\n --- Failed to find a valid answer after {self.max_retry} retries. Falling back to the simple prompt. ---\033[0m")
            # Temporarily disable stop strings so the simple prompt runs to completion.
            del self.generate_kwargs['stop_strings']
            
            messages = []
            if self.system_prompt is not None:
                messages.append({'role': 'system', 'content': SIMPLE_SYS_PROMPT})
            messages.append({'role': 'user', 'content': self._prepare_content_simple(message, dataset=dataset)})
            conversation_history = copy.deepcopy(messages)
            text = self.processor.apply_chat_template(
                [conversation_history],
                tokenize=False,
                add_generation_prompt=True
            )
            
            images, videos = process_vision_info([conversation_history])
            inputs = self.processor(text=text, images=images, videos=videos, padding=True, return_tensors='pt')
            inputs = inputs.to('cuda')
            generated_ids = self.model.generate(
                **inputs,
                **self.generate_kwargs,
            )
            generated_ids = [
                output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
            ]
            out = self.processor.tokenizer.batch_decode(
                generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
            generated_text_segment = out[0]
            
            self.generate_kwargs['stop_strings'] = SPECIAL_STRING_LIST
            
            # To align with the post-processing below, wrap the output in <answer> tags if they are absent.
            answer_match = re.search(r"<answer>(.*?)</answer>", generated_text_segment, re.DOTALL)
            if not answer_match:
                generated_text_segment = "<answer>" + generated_text_segment + "</answer>"
            conversation_history.append({"role": "assistant", "content": [{"type": "text", "text": generated_text_segment}]})
            
            
        
        final_assistant_response = ""
        for msg in reversed(conversation_history):
            if msg['role'] != 'assistant': continue
            current_content_str = ""
            for item in msg['content']:
                if item['type'] == 'text':
                    current_content_str += item['text']
            final_assistant_response = current_content_str # Get the last full response from assistant
            break
        
          
        if self.post_process:
            print(f"\033[31m--- Final response ---\n{final_assistant_response}\n-------------------------\033[0m")
            # Extract content within <answer> tags from the final assistant response
            answer_match = re.search(r"<answer>(.*?)</answer>", final_assistant_response, re.DOTALL)
            if answer_match:
                final_answer = answer_match.group(1).strip()
            else:
                final_answer = "No answer tag found in the final output."

            # Sometimes the answer is still wrapped in \boxed{}, following the behaviour of
            # Qwen2.5-VL; extract the content inside it (brace-aware, see _extract_box_answer).
            match = re.search(r'\\boxed\{(.*?)\}', final_answer)
            if match:
                final_answer = self._extract_box_answer(final_answer)
            

            if self.verbose:
                print(f'\033[32m{final_answer}\033[0m')
            return final_answer
        else:
            return final_assistant_response

    def generate_inner(self, message, dataset=None):
        return self.generate_inner_transformers(message, dataset=dataset)
